{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Finetune ALXLNET-Bahasa"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "This tutorial is available as an IPython notebook at [Malaya/finetune/alxlnet](https://github.com/huseinzol05/Malaya/tree/master/finetune/alxlnet).\n",
    "    \n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, I am going to show how to finetune the pretrained ALXLNET-Bahasa model using TensorFlow Estimator.\n",
    "\n",
    "TF-Estimator is a TensorFlow module well suited to long training runs, since it handles checkpointing and logging for us."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip3 install tensorflow==1.15"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download pretrained model\n",
    "\n",
    "Pretrained checkpoints are listed at https://github.com/huseinzol05/Malaya/tree/master/pretrained-model/alxlnet#download. In this example, we are going to use the BASE size. Uncomment the lines below to download the pretrained model, tokenizer, and config."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__pycache__\t\t\t modeling.py\r\n",
      "alxlnet-base\t\t\t prepro_utils.py\r\n",
      "alxlnet-base-500k-20-10-2020.gz  sp10m.cased.v9.model\r\n",
      "alxlnet-base_config.json\t tf-estimator-text-classification.ipynb\r\n",
      "custom_modeling.py\t\t xlnet.py\r\n",
      "model_utils.py\r\n"
     ]
    }
   ],
   "source": [
    "# !wget https://f000.backblazeb2.com/file/malaya-model/bert-bahasa/alxlnet-base-500k-20-10-2020.gz\n",
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/preprocess/sp10m.cased.v9.model\n",
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/alxlnet/config/alxlnet-base_config.json\n",
    "# !tar -zxf alxlnet-base-500k-20-10-2020.gz\n",
    "!ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model.ckpt-500000.data-00000-of-00001  model.ckpt-500000.meta\r\n",
      "model.ckpt-500000.index\r\n"
     ]
    }
   ],
   "source": [
    "!ls alxlnet-base"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helper module [malaya/finetune/utils.py](https://github.com/huseinzol05/Malaya/blob/master/finetune/utils.py) lets us train the model on a single GPU or on multiple GPUs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.insert(0, '../')\n",
    "import utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are going to train on a very small Bahasa news sentiment dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Lebih-lebih lagi dengan  kemudahan internet da...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Positive</td>\n",
       "      <td>boleh memberi teguran kepada parti tetapi perl...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Adalah membingungkan mengapa masyarakat Cina b...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Positive</td>\n",
       "      <td>Kami menurunkan defisit daripada 6.7 peratus p...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Ini masalahnya. Bukan rakyat, tetapi sistem</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      label                                               text\n",
       "0  Negative  Lebih-lebih lagi dengan  kemudahan internet da...\n",
       "1  Positive  boleh memberi teguran kepada parti tetapi perl...\n",
       "2  Negative  Adalah membingungkan mengapa masyarakat Cina b...\n",
       "3  Positive  Kami menurunkan defisit daripada 6.7 peratus p...\n",
       "4  Negative        Ini masalahnya. Bukan rakyat, tetapi sistem"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv('../sentiment-data-v2.csv')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Negative', 'Positive']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels = df['label'].values.tolist()\n",
    "texts = df['text'].values.tolist()\n",
    "unique_labels = sorted(list(set(labels)))\n",
    "unique_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/model_utils.py:334: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import xlnet\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import model_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentencepiece as spm\n",
    "from prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "sp_model.Load('sp10m.cased.v9.model')\n",
    "\n",
    "SEG_ID_A = 0\n",
    "SEG_ID_B = 1\n",
    "SEG_ID_CLS = 2\n",
    "SEG_ID_SEP = 3\n",
    "SEG_ID_PAD = 4\n",
    "\n",
    "special_symbols = {\n",
    "    '<unk>': 0,\n",
    "    '<s>': 1,\n",
    "    '</s>': 2,\n",
    "    '<cls>': 3,\n",
    "    '<sep>': 4,\n",
    "    '<pad>': 5,\n",
    "    '<mask>': 6,\n",
    "    '<eod>': 7,\n",
    "    '<eop>': 8,\n",
    "}\n",
    "\n",
    "VOCAB_SIZE = 32000\n",
    "UNK_ID = special_symbols['<unk>']\n",
    "CLS_ID = special_symbols['<cls>']\n",
    "SEP_ID = special_symbols['<sep>']\n",
    "MASK_ID = special_symbols['<mask>']\n",
    "EOD_ID = special_symbols['<eod>']\n",
    "\n",
    "\n",
    "def tokenize_fn(text):\n",
    "    text = preprocess_text(text, lower = False)\n",
    "    return encode_ids(sp_model, text)\n",
    "\n",
    "\n",
    "def token_to_ids(text, maxlen = 512):\n",
    "    tokens_a = tokenize_fn(text)\n",
    "    if len(tokens_a) > maxlen - 2:\n",
    "        tokens_a = tokens_a[: (maxlen - 2)]\n",
    "    segment_id = [SEG_ID_A] * len(tokens_a)\n",
    "    tokens_a.append(SEP_ID)\n",
    "    tokens_a.append(CLS_ID)\n",
    "    segment_id.append(SEG_ID_A)\n",
    "    segment_id.append(SEG_ID_CLS)\n",
    "    input_mask = [0.0] * len(tokens_a)\n",
    "    assert len(tokens_a) == len(input_mask) == len(segment_id)\n",
    "    return {\n",
    "        'input_id': tokens_a,\n",
    "        'input_mask': input_mask,\n",
    "        'segment_id': segment_id,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. `input_id`: integer IDs of the SentencePiece tokens, followed by the `<sep>` and `<cls>` IDs.\n",
    "2. `input_mask`: attention mask. Real tokens get `0`; when shorter sequences are padded during batching, the padded positions get `1`, so the model does not learn padded values as part of the context. https://github.com/zihangdai/xlnet/blob/master/classifier_utils.py#L113\n",
    "3. `segment_id`: used for text-pair classification. For single-text classification every token gets `0`, except the final `<cls>` token, which gets `2`."
   ]
  },
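  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the three fields, the sketch below pads two already-tokenized examples to a common length in plain Python, using the same padding values as the data pipeline in this notebook (`0` for `input_id`, `1.0` for `input_mask`, `4` for `segment_id`). The helper name `pad_batch` and the token IDs are made up for the example; the notebook itself lets `tf.data` do the padding.\n",
    "\n",
    "```python\n",
    "SEG_ID_PAD = 4\n",
    "\n",
    "\n",
    "def pad_batch(examples):\n",
    "    # Hypothetical helper: right-pad every field to the longest example.\n",
    "    maxlen = max(len(e['input_id']) for e in examples)\n",
    "    batch = {'input_id': [], 'input_mask': [], 'segment_id': []}\n",
    "    for e in examples:\n",
    "        n_pad = maxlen - len(e['input_id'])\n",
    "        batch['input_id'].append(e['input_id'] + [0] * n_pad)\n",
    "        # 1.0 marks padding, 0.0 marks real tokens.\n",
    "        batch['input_mask'].append(e['input_mask'] + [1.0] * n_pad)\n",
    "        batch['segment_id'].append(e['segment_id'] + [SEG_ID_PAD] * n_pad)\n",
    "    return batch\n",
    "\n",
    "\n",
    "# Made-up token IDs; 4 is <sep> and 3 is <cls>, as in special_symbols.\n",
    "short = {'input_id': [10, 4, 3], 'input_mask': [0.0] * 3, 'segment_id': [0, 0, 2]}\n",
    "longer = {'input_id': [10, 11, 12, 4, 3], 'input_mask': [0.0] * 5, 'segment_id': [0, 0, 0, 0, 2]}\n",
    "batch = pad_batch([short, longer])\n",
    "print(batch['input_mask'][0])\n",
    "print(batch['segment_id'][0])\n",
    "```\n",
    "\n",
    "The shorter example ends with two padded positions, which carry mask `1.0` and segment `4`."
   ]
  },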
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_id': [1620,\n",
       "  13,\n",
       "  5177,\n",
       "  53,\n",
       "  33,\n",
       "  2808,\n",
       "  3168,\n",
       "  24,\n",
       "  3400,\n",
       "  807,\n",
       "  21,\n",
       "  16179,\n",
       "  31,\n",
       "  742,\n",
       "  578,\n",
       "  17153,\n",
       "  9,\n",
       "  4,\n",
       "  3],\n",
       " 'input_mask': [0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0],\n",
       " 'segment_id': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "token_to_ids(texts[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TF-Estimator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TF-Estimator requires two parts:\n",
    "\n",
    "1. Input pipeline, https://www.tensorflow.org/api_docs/python/tf/data/Dataset\n",
    "2. Model definition, https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate():\n",
    "    while True:\n",
    "        for i in range(len(texts)):\n",
    "            if len(texts[i]) > 5:\n",
    "                d = token_to_ids(texts[i])\n",
    "                d['label'] = [unique_labels.index(labels[i])]\n",
    "                d.pop('tokens', None)\n",
    "                yield d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_id': [1620,\n",
       "  13,\n",
       "  5177,\n",
       "  53,\n",
       "  33,\n",
       "  2808,\n",
       "  3168,\n",
       "  24,\n",
       "  3400,\n",
       "  807,\n",
       "  21,\n",
       "  16179,\n",
       "  31,\n",
       "  742,\n",
       "  578,\n",
       "  17153,\n",
       "  9,\n",
       "  4,\n",
       "  3],\n",
       " 'input_mask': [0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0,\n",
       "  0.0],\n",
       " 'segment_id': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2],\n",
       " 'label': [0]}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g = generate()\n",
    "next(g)"
   ]
  },
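  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input pipeline must be a function that returns a function: the outer function takes the settings, and the inner function, which takes no arguments, builds and returns the dataset.\n",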
    "\n",
    "```python\n",
    "def get_dataset(batch_size = 32, shuffle_size = 32):\n",
    "    def get():\n",
    "        return dataset\n",
    "    return get\n",
    "```"
   ]
  },
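  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why this shape is convenient, here is a toy version of the same pattern in plain Python. The names `make_input_fn` and `input_fn` and the list standing in for a dataset are illustrative only, not part of TensorFlow or Malaya.\n",
    "\n",
    "```python\n",
    "def make_input_fn(batch_size = 32):\n",
    "    # Outer function: captures the settings.\n",
    "    def input_fn():\n",
    "        # Inner function: takes no arguments and builds the data only when called.\n",
    "        data = list(range(10))\n",
    "        return [data[i : i + batch_size] for i in range(0, len(data), batch_size)]\n",
    "\n",
    "    return input_fn\n",
    "\n",
    "\n",
    "input_fn = make_input_fn(batch_size = 4)\n",
    "batches = input_fn()\n",
    "print(batches)\n",
    "```\n",
    "\n",
    "Nothing is built until `input_fn()` is called, which lets the caller decide when and where the dataset is created."
   ]
  },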
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(batch_size = 32, shuffle_size = 32):\n",
    "    def get():\n",
    "        dataset = tf.data.Dataset.from_generator(\n",
    "            generate,\n",
    "            {'input_id': tf.int32, 'input_mask': tf.float32, 'segment_id': tf.int32, 'label': tf.int32},\n",
    "            output_shapes = {\n",
    "                'input_id': tf.TensorShape([None]),\n",
    "                'input_mask': tf.TensorShape([None]),\n",
    "                'segment_id': tf.TensorShape([None]),\n",
    "                'label': tf.TensorShape([None])\n",
    "            },\n",
    "        )\n",
    "        dataset = dataset.shuffle(shuffle_size)\n",
    "        dataset = dataset.padded_batch(\n",
    "            batch_size,\n",
    "            padded_shapes = {\n",
    "                'input_id': tf.TensorShape([None]),\n",
    "                'input_mask': tf.TensorShape([None]),\n",
    "                'segment_id': tf.TensorShape([None]),\n",
    "                'label': tf.TensorShape([None])\n",
    "            },\n",
    "            padding_values = {\n",
    "                'input_id': tf.constant(0, dtype = tf.int32),\n",
    "                'input_mask': tf.constant(1.0, dtype = tf.float32),\n",
    "                'segment_id': tf.constant(4, dtype = tf.int32),\n",
    "                'label': tf.constant(0, dtype = tf.int32),\n",
    "            },\n",
    "        )\n",
    "        return dataset\n",
    "    return get"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Test data pipeline using tf.session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-13-2f00f4f10c26>:4: DatasetV1.make_one_shot_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "iterator = get_dataset()()\n",
    "iterator = iterator.make_one_shot_iterator().get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_id': <tf.Tensor 'IteratorGetNext:0' shape=(?, ?) dtype=int32>,\n",
       " 'input_mask': <tf.Tensor 'IteratorGetNext:1' shape=(?, ?) dtype=float32>,\n",
       " 'segment_id': <tf.Tensor 'IteratorGetNext:3' shape=(?, ?) dtype=int32>,\n",
       " 'label': <tf.Tensor 'IteratorGetNext:2' shape=(?, ?) dtype=int32>}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_id': array([[  19, 4084, 1500, ...,    0,    0,    0],\n",
       "        [2740, 9369,   31, ...,    0,    0,    0],\n",
       "        [1084,  791,  835, ...,    0,    0,    0],\n",
       "        ...,\n",
       "        [ 767,  250,   51, ...,    0,    0,    0],\n",
       "        [3593,   21, 7901, ...,    0,    0,    0],\n",
       "        [8097, 2519,  271, ...,    0,    0,    0]], dtype=int32),\n",
       " 'input_mask': array([[0., 0., 0., ..., 1., 1., 1.],\n",
       "        [0., 0., 0., ..., 1., 1., 1.],\n",
       "        [0., 0., 0., ..., 1., 1., 1.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 1., 1., 1.],\n",
       "        [0., 0., 0., ..., 1., 1., 1.],\n",
       "        [0., 0., 0., ..., 1., 1., 1.]], dtype=float32),\n",
       " 'segment_id': array([[0, 0, 0, ..., 4, 4, 4],\n",
       "        [0, 0, 0, ..., 4, 4, 4],\n",
       "        [0, 0, 0, ..., 4, 4, 4],\n",
       "        ...,\n",
       "        [0, 0, 0, ..., 4, 4, 4],\n",
       "        [0, 0, 0, ..., 4, 4, 4],\n",
       "        [0, 0, 0, ..., 4, 4, 4]], dtype=int32),\n",
       " 'label': array([[0],\n",
       "        [1],\n",
       "        [0],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [0],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [0],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1],\n",
       "        [1]], dtype=int32)}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sess.run(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model definition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It must a function accepts 4 parameters.\n",
    "\n",
    "```python\n",
    "def model_fn(features, labels, mode, params):\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/xlnet.py:70: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "kwargs = dict(\n",
    "    is_training = True,\n",
    "    use_tpu = False,\n",
    "    use_bfloat16 = False,\n",
    "    dropout = 0.1,\n",
    "    dropatt = 0.1,\n",
    "    init = 'normal',\n",
    "    init_range = 0.1,\n",
    "    init_std = 0.05,\n",
    "    clamp_len = -1,\n",
    ")\n",
    "\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "xlnet_config = xlnet.XLNetConfig(json_path = 'alxlnet-base_config.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 10  # despite the name, used below as max_steps (a step count, not passes over the data)\n",
    "batch_size = 32\n",
    "warmup_proportion = 0.1\n",
    "num_train_steps = 10\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
    "learning_rate = 2e-5\n",
    "\n",
    "training_parameters = dict(\n",
    "    decay_method = 'poly',\n",
    "    train_steps = num_train_steps,\n",
    "    learning_rate = learning_rate,\n",
    "    warmup_steps = num_warmup_steps,\n",
    "    min_lr_ratio = 0.0,\n",
    "    weight_decay = 0.00,\n",
    "    adam_epsilon = 1e-8,\n",
    "    num_core_per_host = 1,\n",
    "    lr_layer_decay_rate = 1,\n",
    "    use_tpu = False,\n",
    "    use_bfloat16 = False,\n",
    "    dropout = 0.0,\n",
    "    dropatt = 0.0,\n",
    "    init = 'normal',\n",
    "    init_range = 0.1,\n",
    "    init_std = 0.05,\n",
    "    clip = 1.0,\n",
    "    clamp_len = -1,\n",
    ")"
   ]
  },
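  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `decay_method`, `warmup_steps`, and `min_lr_ratio` settings above describe a learning-rate schedule. The sketch below assumes the schedule used by the original XLNet `model_utils.get_train_op`: linear warmup followed by polynomial decay with power 1. The function name `learning_rate_at` is made up, and the local `model_utils.py` may differ in detail.\n",
    "\n",
    "```python\n",
    "def learning_rate_at(step, base_lr = 2e-5, warmup_steps = 1, train_steps = 10, min_lr_ratio = 0.0):\n",
    "    # Assumes train_steps > warmup_steps.\n",
    "    if step < warmup_steps:\n",
    "        # Linear warmup from 0 up to base_lr.\n",
    "        return base_lr * step / warmup_steps\n",
    "    # Fraction of the decay phase completed, capped at 1.\n",
    "    progress = min((step - warmup_steps) / (train_steps - warmup_steps), 1.0)\n",
    "    end_lr = base_lr * min_lr_ratio\n",
    "    # Polynomial decay with power 1, i.e. a straight line down to end_lr.\n",
    "    return (base_lr - end_lr) * (1.0 - progress) + end_lr\n",
    "\n",
    "\n",
    "schedule = [learning_rate_at(step) for step in range(11)]\n",
    "print(schedule)\n",
    "```\n",
    "\n",
    "With `num_train_steps = 10` and `warmup_proportion = 0.1` there is a single warmup step, after which the rate falls linearly towards `learning_rate * min_lr_ratio`."
   ]
  },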
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Parameter:\n",
    "    def __init__(\n",
    "        self,\n",
    "        decay_method,\n",
    "        warmup_steps,\n",
    "        weight_decay,\n",
    "        adam_epsilon,\n",
    "        num_core_per_host,\n",
    "        lr_layer_decay_rate,\n",
    "        use_tpu,\n",
    "        learning_rate,\n",
    "        train_steps,\n",
    "        min_lr_ratio,\n",
    "        clip,\n",
    "        **kwargs\n",
    "    ):\n",
    "        self.decay_method = decay_method\n",
    "        self.warmup_steps = warmup_steps\n",
    "        self.weight_decay = weight_decay\n",
    "        self.adam_epsilon = adam_epsilon\n",
    "        self.num_core_per_host = num_core_per_host\n",
    "        self.lr_layer_decay_rate = lr_layer_decay_rate\n",
    "        self.use_tpu = use_tpu\n",
    "        self.learning_rate = learning_rate\n",
    "        self.train_steps = train_steps\n",
    "        self.min_lr_ratio = min_lr_ratio\n",
    "        self.clip = clip\n",
    "\n",
    "\n",
    "training_parameters = Parameter(**training_parameters)\n",
    "init_checkpoint = 'alxlnet-base/model.ckpt-500000'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_fn(features, labels, mode, params):\n",
    "    Y = tf.cast(features['label'][:, 0], tf.int32)\n",
    "\n",
    "    xlnet_model = xlnet.XLNetModel(\n",
    "        xlnet_config = xlnet_config,\n",
    "        run_config = xlnet_parameters,\n",
    "        input_ids = tf.transpose(features['input_id'], [1, 0]),\n",
    "        seg_ids = tf.transpose(features['segment_id'], [1, 0]),\n",
    "        input_mask = tf.transpose(features['input_mask'], [1, 0]),\n",
    "    )\n",
    "\n",
    "    output_layer = xlnet_model.get_sequence_output()\n",
    "    output_layer = tf.transpose(output_layer, [1, 0, 2])\n",
    "\n",
    "    logits_seq = tf.layers.dense(output_layer, 2)\n",
    "    logits = logits_seq[:, 0]\n",
    "\n",
    "    loss = tf.reduce_mean(\n",
    "        tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            logits = logits, labels = Y\n",
    "        )\n",
    "    )\n",
    "\n",
    "    tf.identity(loss, 'train_loss')\n",
    "\n",
    "    accuracy = tf.metrics.accuracy(\n",
    "        labels = Y, predictions = tf.argmax(logits, axis = 1)\n",
    "    )\n",
    "    tf.identity(accuracy[1], name = 'train_accuracy')\n",
    "\n",
    "    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)\n",
    "\n",
    "    assignment_map, initialized_variable_names = utils.get_assignment_map_from_checkpoint(\n",
    "        variables, init_checkpoint\n",
    "    )\n",
    "\n",
    "    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)\n",
    "\n",
    "    if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "        train_op, _, _ = model_utils.get_train_op(training_parameters, loss)\n",
    "        estimator_spec = tf.estimator.EstimatorSpec(\n",
    "            mode = mode, loss = loss, train_op = train_op\n",
    "        )\n",
    "\n",
    "    elif mode == tf.estimator.ModeKeys.EVAL:\n",
    "        estimator_spec = tf.estimator.EstimatorSpec(\n",
    "            mode = tf.estimator.ModeKeys.EVAL,\n",
    "            loss = loss,\n",
    "            eval_metric_ops = {'accuracy': accuracy},\n",
    "        )\n",
    "\n",
    "    return estimator_spec"
   ]
  },
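  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tf.metrics.accuracy` returns a running metric, so the `train_accuracy` values in the training log are cumulative over all batches seen so far rather than per-batch. A minimal plain-Python sketch of that behaviour, with a hypothetical `RunningAccuracy` class and made-up labels:\n",
    "\n",
    "```python\n",
    "class RunningAccuracy:\n",
    "    # Hypothetical stand-in for a streaming accuracy metric.\n",
    "    def __init__(self):\n",
    "        self.total = 0.0\n",
    "        self.count = 0.0\n",
    "\n",
    "    def update(self, labels, predictions):\n",
    "        # Accumulate correct predictions and examples across batches.\n",
    "        self.total += sum(1.0 for y, p in zip(labels, predictions) if y == p)\n",
    "        self.count += len(labels)\n",
    "        return self.total / self.count\n",
    "\n",
    "\n",
    "acc = RunningAccuracy()\n",
    "first = acc.update([0, 1, 1, 0], [0, 1, 0, 0])   # 3 of 4 correct\n",
    "second = acc.update([1, 1, 0, 0], [0, 0, 0, 1])  # 1 of 4 correct\n",
    "print(first, second)\n",
    "```\n",
    "\n",
    "The second value averages over both batches (4 of 8 correct), not just the latest one."
   ]
  },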
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initiate training session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = get_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From ../utils.py:62: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.\n",
      "\n",
      "WARNING:tensorflow:From ../utils.py:62: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.\n",
      "\n",
      "INFO:tensorflow:Using config: {'_model_dir': 'finetuned-alxlnet-base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 10, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true\n",
      "graph_options {\n",
      "  rewrite_options {\n",
      "    meta_optimizer_iterations: ONE\n",
      "  }\n",
      "}\n",
      ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 1, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f7a5c1b2e48>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/tensorflow_core/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/xlnet.py:253: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/xlnet.py:253: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/custom_modeling.py:696: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n",
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/custom_modeling.py:703: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/custom_modeling.py:808: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dropout instead.\n",
      "WARNING:tensorflow:From /home/ubuntu/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:271: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/custom_modeling.py:109: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "INFO:tensorflow:**** Trainable Variables ****\n",
      "INFO:tensorflow:  name = model/transformer/r_w_bias:0, shape = (12, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/r_r_bias:0, shape = (12, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/word_embedding/lookup_table:0, shape = (32000, 128), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/word_embedding/lookup_table_2:0, shape = (128, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/r_s_bias:0, shape = (12, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/seg_embed:0, shape = (12, 2, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/rel_attn/q/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/rel_attn/k/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/rel_attn/v/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/rel_attn/r/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/rel_attn/o/kernel:0, shape = (768, 12, 64), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/rel_attn/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/rel_attn/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/ff/layer_1/kernel:0, shape = (768, 3072), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/ff/layer_1/bias:0, shape = (3072,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/ff/layer_2/kernel:0, shape = (3072, 768), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/ff/layer_2/bias:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/ff/LayerNorm/beta:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = model/transformer/layer_shared/ff/LayerNorm/gamma:0, shape = (768,), *INIT_FROM_CKPT*\n",
      "INFO:tensorflow:  name = dense/kernel:0, shape = (768, 2)\n",
      "INFO:tensorflow:  name = dense/bias:0, shape = (2,)\n",
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/model_utils.py:105: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/model_utils.py:119: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/model_utils.py:136: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:From /home/ubuntu/malay/Malaya/finetune/alxlnet/model_utils.py:150: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
      "\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 0 into finetuned-alxlnet-base/model.ckpt.\n",
      "INFO:tensorflow:train_accuracy = 0.34375, train_loss = 1.0174254\n",
      "INFO:tensorflow:loss = 1.0174254, step = 1\n",
      "INFO:tensorflow:global_step/sec: 0.0483137\n",
      "INFO:tensorflow:train_accuracy = 0.4375, train_loss = 0.7347818 (20.699 sec)\n",
      "INFO:tensorflow:loss = 0.7347818, step = 2 (20.698 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.0969523\n",
      "INFO:tensorflow:train_accuracy = 0.4375, train_loss = 0.9868502 (10.314 sec)\n",
      "INFO:tensorflow:loss = 0.9868502, step = 3 (10.315 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.139668\n",
      "INFO:tensorflow:train_accuracy = 0.453125, train_loss = 0.7655573 (7.159 sec)\n",
      "INFO:tensorflow:loss = 0.7655573, step = 4 (7.159 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.16895\n",
      "INFO:tensorflow:train_accuracy = 0.48125, train_loss = 0.7466663 (5.919 sec)\n",
      "INFO:tensorflow:loss = 0.7466663, step = 5 (5.920 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.131785\n",
      "INFO:tensorflow:train_accuracy = 0.44270834, train_loss = 0.94823694 (7.588 sec)\n",
      "INFO:tensorflow:loss = 0.94823694, step = 6 (7.588 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.15756\n",
      "INFO:tensorflow:train_accuracy = 0.42410713, train_loss = 0.8999996 (6.347 sec)\n",
      "INFO:tensorflow:loss = 0.8999996, step = 7 (6.346 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.144356\n",
      "INFO:tensorflow:train_accuracy = 0.41796875, train_loss = 0.92889994 (6.927 sec)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:loss = 0.92889994, step = 8 (6.927 sec)\n",
      "INFO:tensorflow:global_step/sec: 0.121323\n",
      "INFO:tensorflow:train_accuracy = 0.41666666, train_loss = 1.0723866 (8.242 sec)\n",
      "INFO:tensorflow:loss = 1.0723866, step = 9 (8.242 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 10 into finetuned-alxlnet-base/model.ckpt.\n",
      "INFO:tensorflow:global_step/sec: 0.130973\n",
      "INFO:tensorflow:train_accuracy = 0.409375, train_loss = 1.0876663 (7.635 sec)\n",
      "INFO:tensorflow:loss = 1.0876663, step = 10 (7.636 sec)\n",
      "INFO:tensorflow:Loss for final step: 1.0876663.\n"
     ]
    }
   ],
   "source": [
    "train_hooks = [\n",
    "    tf.train.LoggingTensorHook(\n",
    "        ['train_accuracy', 'train_loss'], every_n_iter = 1\n",
    "    )\n",
    "]\n",
    "utils.run_training(\n",
    "    train_fn = train_dataset,\n",
    "    model_fn = model_fn,\n",
    "    model_dir = 'finetuned-alxlnet-base',\n",
    "    num_gpus = 1,\n",
    "    log_step = 1,\n",
    "    save_checkpoint_step = epoch,\n",
    "    max_steps = epoch,\n",
    "    train_hooks = train_hooks,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
